import sys, os
import cv2
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torchvision import models
import torch.nn.functional as F
from utils.spec import mel_spectrogram


def calculate_metrics(pred, gold, args, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
    """
    Calculate the loss and, for cross-entropy, the number of correctly
    predicted non-pad tokens.

    args:
        pred: B x T x C
        gold: B x T
        args: object exposing PAD_TOKEN; targets equal to it are not counted
        input_lengths: B (for CTC), forwarded to calculate_loss
        target_lengths: B (for CTC), forwarded to calculate_loss
        smoothing: forwarded to calculate_loss
        loss_type: "ce" or "ctc"
    returns:
        (loss, num_correct) for "ce", where num_correct is a Python int;
        (loss, None) for "ctc";
        (None, None) after printing a message for any other loss_type.
    """
    # Validate before computing the loss: otherwise an unknown loss_type fails
    # inside calculate_loss and the fallback below can never be reached.
    if loss_type not in ("ce", "ctc"):
        print("loss is not defined")
        return None, None

    loss = calculate_loss(pred, gold, args, input_lengths, target_lengths, smoothing, loss_type)
    if loss_type == "ctc":
        return loss, None

    # reshape (unlike view) also accepts non-contiguous tensors.
    flat_pred = pred.reshape(-1, pred.size(2))  # (B*T) x C
    flat_gold = gold.reshape(-1)  # (B*T)
    predicted = flat_pred.max(1)[1]  # index of the highest-scoring class per position
    non_pad_mask = flat_gold.ne(args.PAD_TOKEN)
    num_correct = predicted.eq(flat_gold).masked_select(non_pad_mask).sum().item()
    return loss, num_correct

def calculate_loss(pred, gold, args, input_lengths=None, target_lengths=None, smoothing=0.0, loss_type="ce"):
    """
    Calculate the cross-entropy loss, optionally with label smoothing.

    args:
        pred: B x T x C
        gold: B x T
        args: object exposing PAD_TOKEN; positions whose target equals it are
            excluded from the loss
        input_lengths: B (for CTC; not used by this function)
        target_lengths: B (for CTC; not used by this function)
        smoothing: label-smoothing factor; values <= 0.0 disable smoothing
        loss_type: only "ce" is implemented here
    returns:
        scalar tensor, the loss averaged over non-pad positions
    raises:
        ValueError: if loss_type is anything other than "ce"
    """
    if loss_type != "ce":
        # Previously this path printed a message and then failed on an unbound
        # local; fail explicitly with the offending value instead.
        raise ValueError("loss is not defined: {!r}".format(loss_type))

    # reshape (unlike view) also accepts non-contiguous tensors.
    pred = pred.reshape(-1, pred.size(2))  # (B*T) x C
    gold = gold.reshape(-1)  # (B*T)

    if smoothing > 0.0:
        eps = smoothing
        num_class = pred.size(1)
        non_pad_mask = gold.ne(args.PAD_TOKEN)

        # Pad positions are remapped to class 0 so scatter always receives a
        # valid index; they are removed again by non_pad_mask below.
        gold_for_scatter = non_pad_mask.long() * gold
        one_hot = torch.zeros_like(pred).scatter(1, gold_for_scatter.view(-1, 1), 1)
        # Target class gets (1 - eps); every other class gets eps / num_class.
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / num_class
        log_prob = F.log_softmax(pred, dim=1)

        num_word = non_pad_mask.sum().item()
        per_token = -(one_hot * log_prob).sum(dim=1)
        return per_token.masked_select(non_pad_mask).sum() / num_word

    return F.cross_entropy(pred, gold, ignore_index=args.PAD_TOKEN, reduction="mean")

def _decode_until_pad(token_ids, pad_token, id2label):
    """Join the id2label entries for token_ids, stopping at the first pad_token."""
    pieces = []
    for token in token_ids:
        token = int(token)
        if token == pad_token:
            break
        pieces.append(id2label[token])
    return "".join(pieces)


def get_loss(model, args, benign, benign_lengths, benign_percentages, adv_tgt, adv_tgt_lengths, id2label):
    """
    Run the model on `benign` against the target `adv_tgt` and return the loss
    together with the decoded text of the first utterance.

    args:
        model: callable invoked as model(benign, benign_lengths, adv_tgt, verbose=False),
            returning (pred, gold, hyp_seq, gold_seq)
        args: object exposing PAD_TOKEN, SOS_CHAR, EOS_CHAR, label_smoothing and loss
        benign: model input
        benign_lengths: forwarded to the model
        benign_percentages: tensor scaled by pred.size(1) to build the
            input_lengths passed to calculate_metrics; it is not modified
        adv_tgt: target sequence forwarded to the model
        adv_tgt_lengths: passed to calculate_metrics as target_lengths
        id2label: mapping from token id to its string label
    returns:
        (loss, num_correct, hypothesis_text, gold_text) for the first utterance,
        with SOS_CHAR and EOS_CHAR removed from both strings
    raises:
        SystemExit: with status 1 if decoding the sequences fails
    """
    pred, gold, hyp_seq, gold_seq = model(benign, benign_lengths, adv_tgt, verbose=False)
    hyp_seq = hyp_seq.cpu()
    gold_seq = gold_seq.cpu()

    # The original guard was labelled "handle case for CTC": any decoding
    # failure is printed and the process terminates.
    try:
        strs_gold = [_decode_until_pad(ut_gold, args.PAD_TOKEN, id2label) for ut_gold in gold_seq]
        strs_hyps = [_decode_until_pad(ut_hyp, args.PAD_TOKEN, id2label) for ut_hyp in hyp_seq]
    except Exception as e:
        print(e)
        # Exit with a non-zero status so the failure is not reported as success.
        sys.exit(1)

    seq_length = pred.size(1)
    # Out-of-place mul: the in-place mul_ rescaled the caller's tensor on every
    # call, which compounds when the same tensor is passed repeatedly.
    sizes = benign_percentages.mul(int(seq_length)).int()

    loss, num_correct = calculate_metrics(
        pred, gold, args, input_lengths=sizes, target_lengths=adv_tgt_lengths,
        smoothing=args.label_smoothing, loss_type=args.loss)

    hyp_text = strs_hyps[0].replace(args.SOS_CHAR, '').replace(args.EOS_CHAR, '')
    gold_text = strs_gold[0].replace(args.SOS_CHAR, '').replace(args.EOS_CHAR, '')

    # NOTE(review): this guard triggers on an infinite loss but only replaces
    # NaN entries (loss != loss), so as written it never changes the value.
    # Left unchanged; confirm whether Inf and/or NaN should be zeroed.
    if loss.item() == float('Inf'):
        loss = torch.where(loss != loss, torch.zeros_like(loss), loss)  # NaN masking

    return loss, num_correct, hyp_text, gold_text

class BaseOptimization:
    """Attack hyper-parameters plus the first-stage perturbation search."""

    def __init__(self, args):
        """Store `args` (must expose `cuda`) and set the default hyper-parameters."""
        self.cuda = args.cuda
        self.args = args

        self.max_iterations = 1000  # 1000 iterations complete about 95% of the optimization work
        self.learning_rate = 0.01
        self.binary_search_steps = 4  # maximum number of binary-search iterations
        self.confidence = 1e2  # initial value of confidence
        self.k = 40  # value of k

        # initial bounds for c
        self.lower_bound = 0
        self.upper_bound = 1e10
        # the best l2, score, and image attack
        self.o_bestl2 = 1e10

    def get_adv_loss(self, adv, benign_lengths, benign_percentages, adv_tgt, adv_tgt_lengths, id2label, model):
        """
        Evaluate get_loss on the adversarial input `adv`, moving it to CUDA
        first when self.cuda is set.

        returns:
            (loss, num_correct, pred_txt, target_txt) as produced by get_loss
        """
        if self.cuda:
            adv = adv.cuda()

        # forward
        return get_loss(model, self.args, adv, benign_lengths, benign_percentages, adv_tgt,
                        adv_tgt_lengths, id2label)

    def _attack_1st_stage(self, original_input, benign_lengths, benign_percentages, adv_tgt, adv_tgt_lengths, id2label, device, model):
        """
        Search for an additive perturbation of `original_input` that makes the
        model output `adv_tgt`.

        Runs 1000 signed-gradient steps of size 0.001. Before each step the
        perturbation is clipped to [-rescale, rescale]; rescale starts at 0.05
        and is multiplied by 0.8 after every iteration in which all target
        tokens are predicted correctly.

        returns:
            (successful_perturbations, rescale): the most recent perturbation
            that matched the whole target, or None if no iteration succeeded,
            and the final rescale value.
        """
        perturbations = torch.zeros_like(original_input).to(device, non_blocking=True).float()

        # Stays None when no iteration reproduces the full target; this name
        # was previously left undefined in that case, so the return failed.
        successful_perturbations = None

        # Initialize rescale
        rescale = 0.05
        # Fixed step size of this stage (self.learning_rate is not used here).
        step_size = 0.001
        # Loop-invariant: number of target tokens that must all be correct.
        target_len = adv_tgt.squeeze().shape[0]

        for iter_1st_stage_idx in range(1000):
            # Clip the perturbations. clamp returns a new tensor, so its .grad
            # is always unset here and needs no explicit zeroing.
            perturbations = torch.clamp(perturbations, -rescale, rescale)
            perturbations.requires_grad = True

            # NOTE(review): the positional arguments presumably configure the
            # FFT size, mel bins, sample rate, hop and window - confirm
            # against utils.spec.mel_spectrogram.
            mel_spec = mel_spectrogram(original_input + perturbations, 1024, 80,
                                       16000, 256, 1024, 0, None,
                                       center=True)

            # Call to forward pass
            loss, num_correct, pred_txt, target_txt = get_loss(model, self.args,
                                                               mel_spec.unsqueeze(0).unsqueeze(0),
                                                               benign_lengths,
                                                               benign_percentages, adv_tgt,
                                                               adv_tgt_lengths, id2label)

            attack_succeeded = num_correct == target_len

            # Save the successful perturbations (the ones just evaluated,
            # before this iteration's update is applied)
            if attack_succeeded:
                successful_perturbations = perturbations.clone().detach()

            loss.backward()

            if perturbations.grad is not None:
                # Signed-gradient descent step; the result is detached so the
                # next iteration starts from a fresh leaf tensor.
                grad_sign = perturbations.grad.sign()
                perturbations = (perturbations.detach() - step_size * grad_sign).detach()

            print("Iteration {:d}, Text: {:s}/{:s}, Num correct: {:s}/{:s}".format(iter_1st_stage_idx, pred_txt,
                                                                                   target_txt, str(num_correct),
                                                                                   str(target_len)))

            # Tighten the clipping bound after every successful iteration
            if attack_succeeded:
                rescale *= 0.8

        return successful_perturbations, rescale
